import argparse
from pathlib import Path
from tqdm import tqdm

# torch

import torch

from einops import repeat

# vision imports

from PIL import Image
from torchvision.utils import make_grid, save_image

# dalle related classes and utils

# from dalle_pytorch import DiscreteVAE, OpenAIDiscreteVAE, VQGanVAE, DALLE
from dalle_pytorch.dalle_pytorch_ori import DALLE_PG_Discrete, DiscreteVAE, OpenAIDiscreteVAE, VQGanVAE, DALLE_PG, DALLE_PG_Discrete, DiscretePGVAE
from dalle_pytorch.tokenizer import tokenizer, HugTokenizer, YttmTokenizer, ChineseTokenizer
from IPython import embed
import os
from pytorch3d.io import save_ply, load_ply


from geometry_utils import render_pts, rotate_pts, render_pts_with_label
import torch.nn.functional as F
# argument parsing
import sys
sys.path.insert(0, './shape2prog')
from criterion import BatchIoU
from misc import decode_multiple_block, execute_shape_program
import h5py
import numpy as np


import sys
sys.path.insert(0, '/home/tiangel/ShapeGF')
try:
    # Optional metric package; presumably provided by the ShapeGF checkout
    # that was just put on sys.path.
    from evaluation.evaluation_metrics import EMD_CD
except Exception:  # noqa: BLE001
    # Best effort: any import-time failure disables the reconstruction
    # evaluation instead of aborting the script. `Exception` (rather than a
    # bare `except`) lets Ctrl-C / SystemExit during the import propagate.
    eval_reconstruciton = False
else:
    eval_reconstruciton = True
sys.path.insert(0, '/home/tiangel/PVD')
from datasets.shapenet_data_pc import ShapeNet15kPointClouds
from datasets.shapenet_data_sv import *
from torch.utils.data import DataLoader
from pytorch3d.loss import chamfer_distance
from scipy.optimize import linear_sum_assignment

def normalize_points_torch(points):
    """Center each point cloud and scale it into the unit sphere.

    Every cloud in the batch is translated so its centroid sits at the
    origin, then divided by the distance of its farthest point, so the
    largest point norm becomes 1.

    A degenerate cloud whose points all coincide has a scale of 0; it is
    returned as all zeros instead of NaN (0 / 0).

    Args:
        points (torch.Tensor): (batch_size, num_points, 3)

    Returns:
        torch.Tensor: normalized points, same shape as the input

    Raises:
        AssertionError: if `points` is not a rank-3 tensor with a last
            dimension of size 3.

    """
    assert points.dim() == 3 and points.size(2) == 3
    centroid = points.mean(dim=1, keepdim=True)
    centered = points - centroid
    # Per-cloud radius: largest point norm, kept as (batch_size, 1, 1).
    radius, _ = centered.norm(dim=2, keepdim=True).max(dim=1, keepdim=True)
    # Only the zero-radius case is altered, so non-degenerate clouds get
    # exactly the same result as a plain division.
    safe_radius = torch.where(radius > 0, radius, torch.ones_like(radius))
    return centered / safe_radius

def _str2bool(value):
    """Convert a command-line string into a bool.

    `type=bool` cannot be used with argparse: bool('False') is True, so any
    non-empty value would enable the flag. This converter accepts the usual
    spellings, case-insensitively.

    Args:
        value: raw argument string (an actual bool is passed through).

    Returns:
        bool: the parsed value. The empty string maps to False, which
        matches what `bool('')` used to return.

    Raises:
        argparse.ArgumentTypeError: if the string is not a recognised
            boolean spelling.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ('true', 't', 'yes', 'y', '1', 'on'):
        return True
    if lowered in ('false', 'f', 'no', 'n', '0', 'off', ''):
        return False
    raise argparse.ArgumentTypeError(f'expected a boolean value, got {value!r}')


# Command-line interface of the script.
parser = argparse.ArgumentParser()

parser.add_argument('--dalle_path', type = str, required = True,
                    help='path to your trained DALL-E')

parser.add_argument('--vqgan_model_path', type=str, default = None,
                   help='path to your trained VQGAN weights. This should be a .ckpt file. (only valid when taming option is enabled)')

parser.add_argument('--vqgan_config_path', type=str, default = None,
                   help='path to your trained VQGAN config. This should be a .yaml file.  (only valid when taming option is enabled)')

parser.add_argument('--pts_path', type = str, default=None,
                    help='path to an input point cloud file')

parser.add_argument('--text', type = str, default = None,
                    help='your text prompt')

parser.add_argument('--num_images', type = int, default = 128, required = False,
                    help='number of images')

parser.add_argument('--batch_size', type = int, default = 4, required = False,
                    help='batch size')

parser.add_argument('--top_k', type = float, default = 0.9, required = False,
                    help='top k filter threshold')

parser.add_argument('--outputs_dir', type = str, default = './outputs/dalle_outputs', required = False,
                    help='output directory')

parser.add_argument('--save_name', type = str, default = '1',
                    help = 'tag inserted into the name of the results directory')

# Accepts true/false, yes/no, 1/0 (see _str2bool); defaults to enabled.
parser.add_argument('--emd', type = _str2bool, default = True,
                    help = 'whether to compute the EMD metric (true/false)')

parser.add_argument('--category', type = str, default = 'chair',
                    help = 'shape category to evaluate')

parser.add_argument('--bpe_path', type = str,
                    help='path to your huggingface BPE json file')

parser.add_argument('--hug', dest='hug', action = 'store_true')

parser.add_argument('--chinese', dest='chinese', action = 'store_true')

parser.add_argument('--taming', dest='taming', action='store_true')

parser.add_argument('--gentxt', dest='gentxt', action='store_true')

# Parse the command line once; the rest of the script reads its settings
# from `args`.
args = parser.parse_args()

# helper fns

def exists(val):
    """Return True for any value other than None (falsy values count)."""
    if val is None:
        return False
    return True

# tokenizer
#
# A BPE file takes priority (HuggingFace tokenizer with --hug, YouTokenToMe
# otherwise); without one, --chinese selects the Chinese tokenizer. When
# neither is given, the `tokenizer` imported at the top of the file is kept.

if not exists(args.bpe_path):
    if args.chinese:
        tokenizer = ChineseTokenizer()
elif args.hug:
    tokenizer = HugTokenizer(args.bpe_path)
else:
    tokenizer = YttmTokenizer(args.bpe_path)

# load DALL-E

# --dalle_path is resolved relative to ./outputs/dalle_models.
dalle_path = Path(os.path.join('./outputs/dalle_models',args.dalle_path))

# Explicit check instead of `assert`, which is stripped under `python -O`.
if not dalle_path.exists():
    raise FileNotFoundError(f'trained DALL-E must exist: {dalle_path}')

# NOTE: torch.load unpickles the file, so only load checkpoints you trust.
# map_location='cpu' avoids depending on the GPU index the checkpoint was
# saved from and avoids holding a second copy of the weights on the GPU;
# load_state_dict copies the tensors onto the model's device later.
load_obj = torch.load(str(dalle_path), map_location='cpu')

# Required entries raise KeyError when missing; the last two are optional.
dalle_params = load_obj.pop('hparams')
vae_params = load_obj.pop('vae_params')
weights = load_obj.pop('weights')
vae_class_name = load_obj.pop('vae_class_name', None)
version = load_obj.pop('version', None)

# friendly print: report which library version produced the checkpoint, or
# warn when the checkpoint carries no version entry.

if not exists(version):
    print('You are loading a model trained on an older version of DALL-E pytorch - it may not be compatible with the most recent version')
else:
    print(f'Loading a model trained with DALLE-pytorch version {version}')

# load VAE
#
# --taming selects the VQGAN built from the given weights/config; otherwise
# the VAE hyper-parameters stored in the checkpoint are used when present,
# and the OpenAI discrete VAE is the fallback.

if args.taming:
    vae = VQGanVAE(args.vqgan_model_path, args.vqgan_config_path)
else:
    vae = DiscreteVAE(**vae_params) if exists(vae_params) else OpenAIDiscreteVAE()

# pgvae = DiscretePGVAE(**pgvae_params)

# When the checkpoint recorded the VAE class it was trained with, the VAE
# built above must be of that same class.
assert (not exists(vae_class_name)) or vae.__class__.__name__ == vae_class_name, f'you trained DALL-E using {vae_class_name} but are trying to generate with {vae.__class__.__name__} - please make sure you are passing in the correct paths and settings for the VAE to use for generation'

# reconstitute DALL-E: build the model from the checkpoint hyper-parameters,
# move it to the GPU, then restore the trained weights.

# dalle = DALLE_PG(vae = vae, pgvae = pgvae, **dalle_params).cuda()
dalle = DALLE_PG_Discrete(vae = vae, **dalle_params)
dalle = dalle.cuda()
dalle.load_state_dict(weights)

# generate images

image_size = vae.image_size

# texts = args.text.split('|')

def get_mvr_dataset(pc_dataroot, views_root, npoints,category):
    """Build the validation dataset used for multi-view evaluation.

    The training split is instantiated only to read its
    `all_points_mean` / `all_points_std` attributes, which are handed to
    the validation dataset.

    Args:
        pc_dataroot: root directory of the point-cloud dataset; the cache
            directory is `<pc_dataroot>/../cache`.
        views_root: root directory passed as `root_views`.
        npoints: points sampled per shape (used for both datasets).
        category: single category name; wrapped into a one-element list.

    Returns:
        The `ShapeNet_Multiview_Points` instance for the 'val' split.
    """
    stats_source = ShapeNet15kPointClouds(
        root_dir=pc_dataroot,
        categories=[category],
        split='train',
        tr_sample_size=npoints,
        te_sample_size=npoints,
        scale=1.,
        normalize_per_shape=False,
        normalize_std_per_axis=False,
        random_subsample=True,
    )
    cache_dir = os.path.join(pc_dataroot, '../cache')
    return ShapeNet_Multiview_Points(
        root_pc=pc_dataroot,
        root_views=views_root,
        cache=cache_dir,
        split='val',
        categories=[category],
        npoints=npoints,
        sv_samples=200,
        all_points_mean=stats_source.all_points_mean,
        all_points_std=stats_source.all_points_std,
    )

# 405 airplane + 662 chair + 315 car for test
# NOTE(review): dataset locations are hard-coded to one machine.
ds = get_mvr_dataset('/home/tiangel/datasets/ShapeNetCore.v2.PC15k', '/home/tiangel/datasets/GenReData/',
                        # 10000, args.category)
                        2048, args.category)
# Batch size 1, fixed order: one shape per iteration.
dl = DataLoader(ds, 1, shuffle = False, drop_last=False)


save_dir = os.path.join('./outputs/dalle_outputs','test'+args.save_name+'_pts_condpts')
# makedirs also creates missing parent directories and does not fail when
# the directory already exists (os.mkdir did neither).
os.makedirs(save_dir, exist_ok=True)
 # if not os.path.exists(pred_save_dir):
#     os.mkdir(pred_save_dir)
# pc = load_ply('/home/tiangel/datasets/val_pc_10000p_6k/'+'a1d217ba806367cbc13a0d88b632af1d.ply')
# pc_path = os.path.join('/home/tiangel/datasets/val_pc_10000p_6k/', args.pts_path)
# pc_path = os.path.join(args.pts_path)

# For every shape and every partial view of it, draw `args.num_images`
# samples conditioned on the view and keep the sample with the smallest
# Chamfer distance to the ground truth.
ori_pts = []   # best-match ground-truth clouds, one (1, n, 3) tensor per view
our_pts = []   # matching generated clouds, one (1, n, 3) tensor per view
cd_list = []   # minimum Chamfer distance per view
with torch.no_grad():
    for i, data in enumerate(dl):
        x_all = data['sv_points']
        # The '() n c' pattern requires a leading axis of size 1 (the loader
        # batch); the normalized ground truth is tiled to one copy per sample.
        gt_all = normalize_points_torch(data['test_points'].cuda())
        gt_all = repeat(gt_all, '() n c -> b n c', b = args.num_images)
        for j in range(x_all.shape[1]):
            # Normalize the conditioning view once, then tile it.
            view = normalize_points_torch(x_all[0][j].unsqueeze(0).cuda())
            inputs = repeat(view, '() n c -> b n c', b = args.num_images)
            # split() also yields the final, smaller chunk, so exactly
            # num_images samples are generated even when num_images is not a
            # multiple of batch_size; truncating the chunk count left
            # `outputs` with a different batch size than `gt_all`.
            outputs = torch.cat([
                dalle.generate_pts_cond_pts(chunk, filter_thres = args.top_k)
                for chunk in inputs.split(args.batch_size)
            ])

            cd_dis = chamfer_distance(gt_all, outputs, batch_reduction=None)[0]
            # cd_dis = chamfer_distance(inputs, outputs, batch_reduction=None)[0]
            min_cddis = torch.min(cd_dis)
            cd_list.append(min_cddis)
            min_idx = torch.argmin(cd_dis)
            # ori_pts.append(gt_all)
            ori_pts.append(gt_all[min_idx].unsqueeze(0))
            our_pts.append(outputs[min_idx].unsqueeze(0))
            print(i,j,min_cddis)

# Concatenate the per-view results once; the tensors are reused for the
# HDF5 dump and the Chamfer score, and remain bound to `ori_pts` / `our_pts`.
ori_pts = torch.cat(ori_pts, dim=0)
our_pts = torch.cat(our_pts, dim=0)

# The context manager closes the file even if a write fails.
with h5py.File(os.path.join(save_dir, args.category+'.h5'), 'w') as outfile:
    outfile['our_pts'] = our_pts.cpu().numpy()
    outfile['ori_pts'] = ori_pts.cpu().numpy()

torch.set_printoptions(precision=7)
# Chamfer distance over all selected pairs, with the default reductions.
cd_dis = chamfer_distance(our_pts, ori_pts)[0]
print('cd_dis', cd_dis)

# Earth mover's distance between each ground-truth cloud and its selected
# sample. Points are matched one-to-one by minimizing the summed squared
# distance; the reported value is the mean Euclidean distance of the
# matched pairs, averaged over all clouds.
if args.emd:
    emd_dis = []
    for i in range(ori_pts.shape[0]):
        if i % 100 == 0:
            print('emd',i)
        q1 = ori_pts[i].cpu().numpy()
        q2 = our_pts[i].cpu().numpy()
        # matrix[a, b] = squared distance between q1[a] and q2[b], built by
        # broadcasting instead of materializing repeated copies of both clouds.
        delta = q1[:, None, :] - q2[None, :, :]
        matrix = np.square(delta).sum(axis=-1)
        _, col_ind = linear_sum_assignment(matrix)
        matched = q1 - q2[col_ind]
        emd_dis.append(np.mean(np.sqrt(np.square(matched).sum(axis=-1))))
    # Reported inside the branch: `emd_dis` only exists when EMD is enabled,
    # so printing it unconditionally raised NameError with EMD disabled.
    print('emd_dis:', np.mean(np.array(emd_dis)))

# Open an interactive IPython shell for inspecting the results, then
# terminate the script once the shell is closed.
# NOTE(review): as written, no statement after this exit() call is executed.
embed()
exit()



# NOTE(review): this section sits after an unconditional exit() and is
# therefore not executed. It also uses `ori_shapes`, `ori_voxels`,
# `IoU_list`, `gen_shapes` and `pred_save_dir`, none of which is assigned in
# the visible part of the file (unless a star-import provides them) — TODO
# confirm before re-enabling.
#
# For each of the first 100 shapes: generate program/parameter pairs, clamp
# them, execute them into 32^3 grids and keep the one with the best IoU.
for pc_num, pc in enumerate(ori_shapes[:100]):
    points = normalize_points_torch(torch.Tensor(pc).unsqueeze(0)).cuda()
    points = repeat(points, '() n c -> b n c', b = args.num_images)

    out_pgms = []
    out_params = []
    for pts_chunk in tqdm(points.split(args.batch_size), desc = f'generating images for - {args.category, pc_num}'):
        output = dalle.generate_pgs(pts_chunk, filter_thres = args.top_k)
        out_pgms.append(output[0])
        out_params.append(output[1])

    out_pgms = torch.cat(out_pgms)
    out_params = torch.cat(out_params)
    # Report, then clamp, program entries outside [0, 20].
    for i in range(out_pgms.shape[0]):
        out_pgm_cur = out_pgms[i]
        less_zero = torch.sum(out_pgm_cur < 0)
        be_up = torch.sum(out_pgm_cur > 20)
        if less_zero + be_up > 0:
            print('output:%d, less_zero:%d, bigger_up:%d'%(i, less_zero, be_up))
        out_pgms[i][out_pgm_cur < 0] = 0
        out_pgms[i][out_pgm_cur >20] = 20

    # Report, then clamp, parameter entries outside [-27, 29].
    for i in range(out_params.shape[0]):
        out_param_cur = out_params[i]
        less_zero = torch.sum(out_param_cur < -27)
        be_up = torch.sum(out_param_cur > 29)
        if less_zero + be_up > 0:
            print('output:%d, less_zero:%d, bigger_up:%d'%(i, less_zero, be_up))
        out_params[i][out_param_cur < -27] = -27
        out_params[i][out_param_cur > 29] = 29

    out_pgms = out_pgms.reshape(-1,10,3).cpu().numpy()
    out_params = out_params.reshape(-1, 10, 3 ,7).cpu().numpy()
    res = []
    num_shapes = out_pgms.shape[0]
    for i in range(num_shapes):
        # Samples whose program cannot be executed are skipped.
        # NOTE(review): the bare `except` also catches KeyboardInterrupt.
        try:
            data = execute_shape_program(out_pgms[i], out_params[i])
        except:
            print('render program wrong', pc_num, i)
            continue
        res.append(data.reshape((1, 32, 32, 32)))
    if len(res) == 0:
        # No sample could be executed for this shape; it contributes nothing.
        print('big error')
        continue
    res = np.concatenate(res, axis=0)
    # Tile the reference grid so it has one copy per successfully executed sample.
    target = np.transpose(np.tile(np.expand_dims(ori_voxels[pc_num],axis=-1),res.shape[0]),[3,0,1,2])
    ious = BatchIoU(res, target)
    IoU_list.append(np.max(ious))
    gen_shapes.append(res[np.argmax(ious)].reshape(1, 32, 32, 32))
gen_shapes = np.concatenate(gen_shapes, axis=0) 
IoU_list = np.array(IoU_list)
print("Mean IoU: {:.3f}".format(IoU_list.mean()))

# Store the best grid per shape together with its IoU.
output_file = os.path.join(pred_save_dir, args.category+'.h5')
f_train = h5py.File(output_file, 'w')
f_train['shape'] = gen_shapes
f_train['iou'] = IoU_list
f_train.close()


# NOTE(review): not executed as written — the script exits earlier.
# When the optional EMD_CD import succeeded, each ground-truth cloud is
# compared against a slice of 20 generated clouds (`our_pts[i*20:(i+1)*20]`).
# NOTE(review): `cd_dis` was last bound to the result of chamfer_distance,
# so `.append` presumably fails on it — confirm before re-enabling.
if eval_reconstruciton:
    for i in range(int(ori_pts.shape[0])):
        rec_res = EMD_CD(ori_pts[i].unsqueeze(0).repeat([20, 1, 1]).cuda(), our_pts[i*20:(i+1)*20].cuda(), 20, reduced=False)
        print(i, 'CD:', rec_res['MMD-CD'].mean(), 'EMD:', rec_res['MMD-EMD'].mean())
        cd_dis.append(rec_res['MMD-CD'])
        emd_dis.append(rec_res['MMD-EMD'])
else:
    print('eval_reconstruciton is false')
    embed()
exit()

# NOTE(review): not executed as written — the script exits earlier, and
# `texts` is never assigned (its only assignment is commented out above).
#
# For each prompt: tokenize it (or let the model generate the text with
# --gentxt), sample `args.num_images` outputs in chunks of `args.batch_size`
# and write them to a per-prompt directory under `args.outputs_dir`.
for j, text in tqdm(enumerate(texts)):
    if args.gentxt:
        text_tokens, gen_texts = dalle.generate_texts(tokenizer, text=text, filter_thres = args.top_k)
        text = gen_texts[0]
    else:
        text_tokens = tokenizer.tokenize([text], dalle.text_seq_len).cuda()

    text_tokens = repeat(text_tokens, '() n -> b n', b = args.num_images)

    outputs = []

    for text_chunk in tqdm(text_tokens.split(args.batch_size), desc = f'generating images for - {text}'):
        output = dalle.generate_images(text_chunk, filter_thres = args.top_k)
        outputs.append(output)

    outputs = torch.cat(outputs)

    # save all images

    # Directory name: the prompt with spaces replaced, cut to 100 characters.
    file_name = text 
    outputs_dir = Path(args.outputs_dir) / file_name.replace(' ', '_')[:(100)]
    outputs_dir.mkdir(parents = True, exist_ok = True)

    for i, image in tqdm(enumerate(outputs), desc = 'saving images'):
        save_ply(os.path.join(outputs_dir,'%04d.ply'%i),image)
        # save_image(image, outputs_dir / f'{i}.jpg', normalize=True)
        # caption.txt is rewritten with the same prompt on every iteration.
        with open(outputs_dir / 'caption.txt', 'w') as f:
            f.write(file_name)

    print(f'created {args.num_images} images at "{str(outputs_dir)}"')
